package org.springframework.security.oauth2.client.token;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.support.SqlLobValue;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.util.SerializationUtils;
import org.springframework.util.Assert;
/**
* Implementation of token services that stores tokens in a database for retrieval by client applications.
*
* @author Dave Syer
*/
public class JdbcClientTokenServices implements ClientTokenServices {

	private static final Log LOG = LogFactory.getLog(JdbcClientTokenServices.class);

	// Parameter order must match the values/types arrays built in saveAccessToken.
	private static final String DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT = "insert into oauth_client_token (token_id, token, authentication_id, user_name, client_id) values (?, ?, ?, ?, ?)";

	// getAccessToken deserializes the token from the SECOND selected column, so custom
	// select statements must keep the serialized token in column 2.
	private static final String DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT = "select token_id, token from oauth_client_token where authentication_id = ?";

	private static final String DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT = "delete from oauth_client_token where authentication_id = ?";

	private String insertAccessTokenSql = DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT;

	private String selectAccessTokenSql = DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT;

	private String deleteAccessTokenSql = DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT;

	// Derives the authentication_id lookup key from the resource and the current authentication.
	private ClientKeyGenerator keyGenerator = new DefaultClientKeyGenerator();

	private final JdbcTemplate jdbcTemplate;

	/**
	 * Create token services that read and write tokens through the given data source.
	 *
	 * @param dataSource the data source holding the token table; must not be {@code null}
	 * @throws IllegalArgumentException if {@code dataSource} is {@code null}
	 */
	public JdbcClientTokenServices(DataSource dataSource) {
		Assert.notNull(dataSource, "DataSource required");
		this.jdbcTemplate = new JdbcTemplate(dataSource);
	}

	/**
	 * Set the strategy used to compute the {@code authentication_id} key for a resource/authentication pair.
	 *
	 * @param keyGenerator the key generator; must not be {@code null}
	 * @throws IllegalArgumentException if {@code keyGenerator} is {@code null} (fail fast instead of an NPE on
	 * the first token operation)
	 */
	public void setClientKeyGenerator(ClientKeyGenerator keyGenerator) {
		Assert.notNull(keyGenerator, "ClientKeyGenerator required");
		this.keyGenerator = keyGenerator;
	}

	/**
	 * Look up the stored access token for the given resource and authentication.
	 *
	 * @param resource the protected resource the token was obtained for
	 * @param authentication the current user authentication (passed to the key generator as-is)
	 * @return the deserialized token, or {@code null} if no row matches the computed key
	 */
	public OAuth2AccessToken getAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
		OAuth2AccessToken accessToken = null;
		try {
			accessToken = jdbcTemplate.queryForObject(selectAccessTokenSql, new RowMapper<OAuth2AccessToken>() {
				public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
					// Column 2 holds the serialized token (see the default select statement).
					return SerializationUtils.deserialize(rs.getBytes(2));
				}
			}, keyGenerator.extractKey(resource, authentication));
		}
		catch (EmptyResultDataAccessException e) {
			// No stored token is an expected outcome, not an error. The guard level must match the
			// level actually logged (previously isInfoEnabled() guarded a debug() call).
			if (LOG.isDebugEnabled()) {
				LOG.debug("Failed to find access token for authentication " + authentication);
			}
		}
		return accessToken;
	}

	/**
	 * Store an access token, replacing any token already stored under the same key.
	 *
	 * <p>The delete and insert are issued as two separate statements. Callers that need them to be
	 * atomic presumably have to run this inside a transaction (not verified here).
	 *
	 * @param resource the protected resource; its client id is stored in the {@code client_id} column
	 * @param authentication the current user authentication, may be {@code null} (then no user name is stored)
	 * @param accessToken the token to store; its value is the {@code token_id} and the whole object is
	 * serialized into the {@code token} BLOB column
	 */
	public void saveAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication,
			OAuth2AccessToken accessToken) {
		// Delete first so that re-saving under the same key replaces rather than duplicates.
		removeAccessToken(resource, authentication);
		String name = authentication == null ? null : authentication.getName();
		jdbcTemplate.update(
				insertAccessTokenSql,
				new Object[] { accessToken.getValue(), new SqlLobValue(SerializationUtils.serialize(accessToken)),
						keyGenerator.extractKey(resource, authentication), name, resource.getClientId() },
				new int[] { Types.VARCHAR, Types.BLOB, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR });
	}

	/**
	 * Delete any stored token for the given resource and authentication.
	 *
	 * @param resource the protected resource the token was obtained for
	 * @param authentication the current user authentication (passed to the key generator as-is)
	 */
	public void removeAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
		jdbcTemplate.update(deleteAccessTokenSql, keyGenerator.extractKey(resource, authentication));
	}

	/**
	 * @param insertAccessTokenSql replacement insert statement; it must accept the same five parameters in
	 * the same order as the default (token_id, token, authentication_id, user_name, client_id)
	 */
	public void setInsertAccessTokenSql(String insertAccessTokenSql) {
		this.insertAccessTokenSql = insertAccessTokenSql;
	}

	/**
	 * @param selectAccessTokenSql replacement select statement; it takes the key as its single parameter and
	 * must return the serialized token in the second column
	 */
	public void setSelectAccessTokenSql(String selectAccessTokenSql) {
		this.selectAccessTokenSql = selectAccessTokenSql;
	}

	/**
	 * @param deleteAccessTokenSql replacement delete statement; it takes the key as its single parameter
	 */
	public void setDeleteAccessTokenSql(String deleteAccessTokenSql) {
		this.deleteAccessTokenSql = deleteAccessTokenSql;
	}
}